{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Few-Shot learning with Reptile\n",
    "\n",
    "**Author:** [ADMoreau](https://github.com/ADMoreau)<br>\n",
    "**Date created:** 2020/05/21<br>\n",
    "**Last modified:** 2020/05/30<br>\n",
    "**Description:** Few-shot classification of the Omniglot dataset using Reptile."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "The [Reptile](https://arxiv.org/abs/1803.02999) algorithm was developed by OpenAI to\n",
    "perform model agnostic meta-learning. Specifically, this algorithm was designed to\n",
    "quickly learn to perform new tasks with minimal training (few-shot learning).\n",
    "The algorithm works by performing Stochastic Gradient Descent using the\n",
    "difference between weights trained on a mini-batch of never before seen data and the\n",
    "model weights prior to training over a fixed number of meta-iterations.\n"
   ]
  },
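  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the update concrete (a sketch in the notation of the Reptile paper; the symbols\n",
    "are not names used in this notebook's code): let $\\phi$ be the model weights before a\n",
    "meta-iteration and $\\tilde{\\phi}$ the weights after training on the sampled mini-batch.\n",
    "The meta step interpolates between the two:\n",
    "\n",
    "$$\\phi \\leftarrow \\phi + \\epsilon \\, (\\tilde{\\phi} - \\phi)$$\n",
    "\n",
    "where $\\epsilon$ is the meta step size. The training loop in this example also decays\n",
    "$\\epsilon$ linearly, so at meta-iteration $t$ out of $T$ it uses\n",
    "\n",
    "$$\\epsilon_t = \\left(1 - \\frac{t}{T}\\right) \\epsilon_0$$\n",
    "\n",
    "with $\\epsilon_0$ = `meta_step_size` and $T$ = `meta_iters`.\n"
   ]
  },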
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import tensorflow_datasets as tfds\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Define the Hyperparameters\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "learning_rate = 0.003\n",
    "meta_step_size = 0.25\n",
    "\n",
    "inner_batch_size = 25\n",
    "eval_batch_size = 25\n",
    "\n",
    "meta_iters = 2000\n",
    "eval_iters = 5\n",
    "inner_iters = 4\n",
    "\n",
    "eval_interval = 1\n",
    "train_shots = 20\n",
    "shots = 5\n",
    "classes = 5\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Prepare the data\n",
    "\n",
    "The [Omniglot dataset](https://github.com/brendenlake/omniglot/) is a dataset of 1,623\n",
    "characters taken from 50 different alphabets, with 20 examples for each character.\n",
    "The 20 samples for each character were drawn online via Amazon's Mechanical Turk. For the\n",
    "few-shot learning task, `k` samples (or \"shots\") are drawn randomly from `n` randomly-chosen\n",
    "classes. These `n` numerical values are used to create a new set of temporary labels to use\n",
    "to test the model's ability to learn a new task given few examples. In other words, if you\n",
    "are training on 5 classes, your new class labels will be either 0, 1, 2, 3, or 4.\n",
    "Omniglot is a great dataset for this task since there are many different classes to draw\n",
    "from, with a reasonable number of samples for each class.\n"
   ]
  },
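  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The relabeling can be written out as follows (the notation is introduced here for\n",
    "illustration only): suppose the sampled classes are $c_1, \\dots, c_n$ and $x_{i,j}$ is the\n",
    "$j$-th image drawn from class $c_i$. Each mini dataset is then the temporary labeled set\n",
    "\n",
    "$$S = \\{ (x_{i,j},\\; i - 1) : 1 \\le i \\le n,\\; 1 \\le j \\le k \\}, \\qquad |S| = n k$$\n",
    "\n",
    "With `classes = 5` and `shots = 5` as defined above, an evaluation mini dataset has\n",
    "$5 \\times 5 = 25$ labeled images with labels in $\\{0, \\dots, 4\\}$, whichever Omniglot\n",
    "characters were sampled. When a test split is requested, one extra image per class is held\n",
    "out and given the same temporary label.\n"
   ]
  },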
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class Dataset:\n",
    "    # This class will facilitate the creation of a few-shot dataset\n",
    "    # from the Omniglot dataset that can be sampled from quickly while also\n",
    "    # allowing to create new labels at the same time.\n",
    "    def __init__(self, training):\n",
    "        # Download the tfrecord files containing the omniglot data and convert to a\n",
    "        # dataset.\n",
    "        split = \"train\" if training else \"test\"\n",
    "        ds = tfds.load(\"omniglot\", split=split, as_supervised=True, shuffle_files=False)\n",
    "        # Iterate over the dataset to get each individual image and its class,\n",
    "        # and put that data into a dictionary.\n",
    "        self.data = {}\n",
    "\n",
    "        def extraction(image, label):\n",
    "            # This function will shrink the Omniglot images to the desired size,\n",
    "            # scale pixel values and convert the RGB image to grayscale\n",
    "            image = tf.image.convert_image_dtype(image, tf.float32)\n",
    "            image = tf.image.rgb_to_grayscale(image)\n",
    "            image = tf.image.resize(image, [28, 28])\n",
    "            return image, label\n",
    "\n",
    "        for image, label in ds.map(extraction):\n",
    "            image = image.numpy()\n",
    "            label = str(label.numpy())\n",
    "            if label not in self.data:\n",
    "                self.data[label] = []\n",
    "            self.data[label].append(image)\n",
    "            self.labels = list(self.data.keys())\n",
    "\n",
    "    def get_mini_dataset(\n",
    "        self, batch_size, repetitions, shots, num_classes, split=False\n",
    "    ):\n",
    "        temp_labels = np.zeros(shape=(num_classes * shots))\n",
    "        temp_images = np.zeros(shape=(num_classes * shots, 28, 28, 1))\n",
    "        if split:\n",
    "            test_labels = np.zeros(shape=(num_classes))\n",
    "            test_images = np.zeros(shape=(num_classes, 28, 28, 1))\n",
    "\n",
    "        # Get a random subset of labels from the entire label set.\n",
    "        label_subset = random.choices(self.labels, k=num_classes)\n",
    "        for class_idx, class_obj in enumerate(label_subset):\n",
    "            # Use enumerated index value as a temporary label for mini-batch in\n",
    "            # few shot learning.\n",
    "            temp_labels[class_idx * shots : (class_idx + 1) * shots] = class_idx\n",
    "            # If creating a split dataset for testing, select an extra sample from each\n",
    "            # label to create the test dataset.\n",
    "            if split:\n",
    "                test_labels[class_idx] = class_idx\n",
    "                images_to_split = random.choices(\n",
    "                    self.data[label_subset[class_idx]], k=shots + 1\n",
    "                )\n",
    "                test_images[class_idx] = images_to_split[-1]\n",
    "                temp_images[\n",
    "                    class_idx * shots : (class_idx + 1) * shots\n",
    "                ] = images_to_split[:-1]\n",
    "            else:\n",
    "                # For each index in the randomly selected label_subset, sample the\n",
    "                # necessary number of images.\n",
    "                temp_images[\n",
    "                    class_idx * shots : (class_idx + 1) * shots\n",
    "                ] = random.choices(self.data[label_subset[class_idx]], k=shots)\n",
    "\n",
    "        dataset = tf.data.Dataset.from_tensor_slices(\n",
    "            (temp_images.astype(np.float32), temp_labels.astype(np.int32))\n",
    "        )\n",
    "        dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)\n",
    "        if split:\n",
    "            return dataset, test_images, test_labels\n",
    "        return dataset\n",
    "\n",
    "\n",
    "import urllib3\n",
    "\n",
    "urllib3.disable_warnings()  # Disable SSL warnings that may happen during download.\n",
    "train_dataset = Dataset(training=True)\n",
    "test_dataset = Dataset(training=False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Visualize some examples from the dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "_, axarr = plt.subplots(nrows=5, ncols=5, figsize=(20, 20))\n",
    "\n",
    "sample_keys = list(train_dataset.data.keys())\n",
    "\n",
    "for a in range(5):\n",
    "    for b in range(5):\n",
    "        temp_image = train_dataset.data[sample_keys[a]][b]\n",
    "        temp_image = np.stack((temp_image[:, :, 0],) * 3, axis=2)\n",
    "        temp_image *= 255\n",
    "        temp_image = np.clip(temp_image, 0, 255).astype(\"uint8\")\n",
    "        if b == 2:\n",
    "            axarr[a, b].set_title(\"Class : \" + sample_keys[a])\n",
    "        axarr[a, b].imshow(temp_image, cmap=\"gray\")\n",
    "        axarr[a, b].xaxis.set_visible(False)\n",
    "        axarr[a, b].yaxis.set_visible(False)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Build the model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "def conv_bn(x):\n",
    "    x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding=\"same\")(x)\n",
    "    x = layers.BatchNormalization()(x)\n",
    "    return layers.ReLU()(x)\n",
    "\n",
    "\n",
    "inputs = layers.Input(shape=(28, 28, 1))\n",
    "x = conv_bn(inputs)\n",
    "x = conv_bn(x)\n",
    "x = conv_bn(x)\n",
    "x = conv_bn(x)\n",
    "x = layers.Flatten()(x)\n",
    "outputs = layers.Dense(classes, activation=\"softmax\")(x)\n",
    "model = keras.Model(inputs=inputs, outputs=outputs)\n",
    "model.compile()\n",
    "optimizer = keras.optimizers.SGD(learning_rate=learning_rate)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train the model\n"
   ]
  },
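  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a quick sanity check on how the hyperparameters interact (worked out from the values\n",
    "defined earlier): each training mini dataset holds `classes * train_shots` images, batched\n",
    "by `inner_batch_size` and repeated `inner_iters` times, so one meta-iteration performs\n",
    "\n",
    "$$\\frac{5 \\times 20}{25} \\times 4 = 16$$\n",
    "\n",
    "inner SGD steps before the meta update. Each evaluation mini dataset holds\n",
    "`classes * shots` images, which fit in a single batch of `eval_batch_size`, repeated\n",
    "`eval_iters` times:\n",
    "\n",
    "$$\\frac{5 \\times 5}{25} \\times 5 = 5$$\n",
    "\n",
    "adaptation steps before the held-out images are scored.\n"
   ]
  },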
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "training = []\n",
    "testing = []\n",
    "for meta_iter in range(meta_iters):\n",
    "    frac_done = meta_iter / meta_iters\n",
    "    cur_meta_step_size = (1 - frac_done) * meta_step_size\n",
    "    # Temporarily save the weights from the model.\n",
    "    old_vars = model.get_weights()\n",
    "    # Get a sample from the full dataset.\n",
    "    mini_dataset = train_dataset.get_mini_dataset(\n",
    "        inner_batch_size, inner_iters, train_shots, classes\n",
    "    )\n",
    "    for images, labels in mini_dataset:\n",
    "        with tf.GradientTape() as tape:\n",
    "            preds = model(images)\n",
    "            loss = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "        grads = tape.gradient(loss, model.trainable_weights)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "    new_vars = model.get_weights()\n",
    "    # Perform SGD for the meta step.\n",
    "    for var in range(len(new_vars)):\n",
    "        new_vars[var] = old_vars[var] + (\n",
    "            (new_vars[var] - old_vars[var]) * cur_meta_step_size\n",
    "        )\n",
    "    # After the meta-learning step, reload the newly-trained weights into the model.\n",
    "    model.set_weights(new_vars)\n",
    "    # Evaluation loop\n",
    "    if meta_iter % eval_interval == 0:\n",
    "        accuracies = []\n",
    "        for dataset in (train_dataset, test_dataset):\n",
    "            # Sample a mini dataset from the full dataset.\n",
    "            train_set, test_images, test_labels = dataset.get_mini_dataset(\n",
    "                eval_batch_size, eval_iters, shots, classes, split=True\n",
    "            )\n",
    "            old_vars = model.get_weights()\n",
    "            # Train on the samples and get the resulting accuracies.\n",
    "            for images, labels in train_set:\n",
    "                with tf.GradientTape() as tape:\n",
    "                    preds = model(images)\n",
    "                    loss = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "                grads = tape.gradient(loss, model.trainable_weights)\n",
    "                optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "            test_preds = model.predict(test_images)\n",
    "            test_preds = tf.argmax(test_preds).numpy()\n",
    "            num_correct = (test_preds == test_labels).sum()\n",
    "            # Reset the weights after getting the evaluation accuracies.\n",
    "            model.set_weights(old_vars)\n",
    "            accuracies.append(num_correct / classes)\n",
    "        training.append(accuracies[0])\n",
    "        testing.append(accuracies[1])\n",
    "        if meta_iter % 100 == 0:\n",
    "            print(\n",
    "                \"batch %d: train=%f test=%f\" % (meta_iter, accuracies[0], accuracies[1])\n",
    "            )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Visualize Results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# First, some preprocessing to smooth the training and testing arrays for display.\n",
    "window_length = 100\n",
    "train_s = np.r_[\n",
    "    training[window_length - 1 : 0 : -1], training, training[-1:-window_length:-1]\n",
    "]\n",
    "test_s = np.r_[\n",
    "    testing[window_length - 1 : 0 : -1], testing, testing[-1:-window_length:-1]\n",
    "]\n",
    "w = np.hamming(window_length)\n",
    "train_y = np.convolve(w / w.sum(), train_s, mode=\"valid\")\n",
    "test_y = np.convolve(w / w.sum(), test_s, mode=\"valid\")\n",
    "\n",
    "# Display the training accuracies.\n",
    "x = np.arange(0, len(test_y), 1)\n",
    "plt.plot(x, test_y, x, train_y)\n",
    "plt.legend([\"test\", \"train\"])\n",
    "plt.grid()\n",
    "\n",
    "train_set, test_images, test_labels = dataset.get_mini_dataset(\n",
    "    eval_batch_size, eval_iters, shots, classes, split=True\n",
    ")\n",
    "for images, labels in train_set:\n",
    "    with tf.GradientTape() as tape:\n",
    "        preds = model(images)\n",
    "        loss = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "    grads = tape.gradient(loss, model.trainable_weights)\n",
    "    optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "test_preds = model.predict(test_images)\n",
    "test_preds = tf.argmax(test_preds).numpy()\n",
    "\n",
    "_, axarr = plt.subplots(nrows=1, ncols=5, figsize=(20, 20))\n",
    "\n",
    "sample_keys = list(train_dataset.data.keys())\n",
    "\n",
    "for i, ax in zip(range(5), axarr):\n",
    "    temp_image = np.stack((test_images[i, :, :, 0],) * 3, axis=2)\n",
    "    temp_image *= 255\n",
    "    temp_image = np.clip(temp_image, 0, 255).astype(\"uint8\")\n",
    "    ax.set_title(\n",
    "        \"Label : {}, Prediction : {}\".format(int(test_labels[i]), test_preds[i])\n",
    "    )\n",
    "    ax.imshow(temp_image, cmap=\"gray\")\n",
    "    ax.xaxis.set_visible(False)\n",
    "    ax.yaxis.set_visible(False)\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "reptile",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}